Models with covarying parameters

Jesse Brunner

Example: Growing virus isolates

Recall from lab that an exponentially growing virus population, \(V\), should look like this: \[ V(t) = V_0 e^{rt} \] which, on the log scale, is a nice linear model: \[ \log[V(t)] = \log(V_0) + rt \]

Our study:

  • twenty strains of virus grown in cell culture
  • take samples on days 1, 3, and 5 and titer the virus
  • expect strains to vary in their growth rates
  • What if we expect that viruses that are very infectious (have a high \(V_0\) for a given dose) have a low growth rate (\(r\)), and vice versa?
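
As a quick numerical check of the log-linear form, the sketch below evaluates both expressions for a single hypothetical strain. The values of `V0` and `r` are illustrative assumptions, not estimates from the study.

```r
# Illustrative values (assumptions for this sketch only)
V0 <- 1e4          # initial virus titer
r  <- 2.5          # growth rate per day
t  <- c(1, 3, 5)   # sampling days from the study design

V <- V0 * exp(r * t)            # exponential growth on the natural scale
logV_direct <- log(V)           # log of the exponential curve
logV_linear <- log(V0) + r * t  # linear model on the log scale

# The two agree up to floating-point error
all.equal(logV_direct, logV_linear)

# The slope of log(V) against time recovers r
slope <- unname(coef(lm(logV_direct ~ t))["t"])
```

On the log scale the intercept is \(\log(V_0)\) and the slope is \(r\), which is why the later models put priors on an intercept and a slope rather than on \(V_0\) directly.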

Simulating our example

We need to simulate data in which \(r\) and \(\log(V_0)\) are correlated across strains.

r_mu <- 2.5 # mean growth rate
r_sig <- 0.75 # sd of growth rates

lV0_mu <- log(1e4) # mean of log initial virus population 
lV0_sig <- 1 # sd of log initial virus population

rho <- -0.65 # correlation between lV0 and r

# vector of means
Mu <- c(lV0_mu, r_mu)
# vector of standard deviations
sigmas <- c(lV0_sig, r_sig)

How to deal with covariances in parameters?

We want parameters drawn from a joint distribution of \(r\) and \(\log(V_0)\) where high values in one are associated with low values in the other.

We need a matrix describing this covariance.

Option 1 to build covariance matrix:

# covariance of the two parameters
(cov_lV0r <- r_sig * lV0_sig * rho)
[1] -0.4875
# covariance matrix
(Sigma <- matrix( c(lV0_sig^2, cov_lV0r, 
                    cov_lV0r, r_sig^2), 
                  ncol=2, byrow = TRUE))
        [,1]    [,2]
[1,]  1.0000 -0.4875
[2,] -0.4875  0.5625

Covariance between parameters \[ \text{cov}(lV_0, r) = \sigma_{lV_0} \times \sigma_r \times \rho \]

Covariance matrix \[ \begin{aligned} \boldsymbol{\Sigma} &= \begin{bmatrix} \text{var}(x) & \text{cov}(x,y) \\ \text{cov}(x,y) & \text{var}(y) \end{bmatrix} \\ &= \begin{bmatrix} \sigma^2_{lV_0} & \text{cov}(lV_0, r) \\ \text{cov}(lV_0, r) & \sigma^2_{r} \end{bmatrix} \end{aligned} \]

Option 2 to build covariance matrix:

# first make correlation matrix
(Rho <- matrix( c(1, rho, 
                  rho, 1), 
                ncol=2, byrow=TRUE))
      [,1]  [,2]
[1,]  1.00 -0.65
[2,] -0.65  1.00
# Then matrix multiply (%*%) to get covariance matrix
(Sigma <- diag(sigmas) %*% Rho %*% diag(sigmas))
        [,1]    [,2]
[1,]  1.0000 -0.4875
[2,] -0.4875  0.5625

Correlation matrix \[ \boldsymbol{\rho} = \begin{bmatrix} 1 & \rho \\ \rho & 1 \end{bmatrix} \]

Covariance matrix \[ \begin{aligned} \boldsymbol{\Sigma} &= \text{diag}(\boldsymbol{\sigma}) \times \boldsymbol{\rho} \times \text{diag}(\boldsymbol{\sigma}) \\ &= \begin{bmatrix} \sigma_{lV_0} & 0 \\ 0 & \sigma_r \end{bmatrix} \times \begin{bmatrix} 1 & \rho \\ \rho & 1 \end{bmatrix} \times \begin{bmatrix} \sigma_{lV_0} & 0 \\ 0 & \sigma_r \end{bmatrix} \end{aligned} \]
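
The two constructions should give the same matrix. A minimal sketch that checks this, using the same standard deviations and correlation as the simulation (the names `Sigma1`, `Sigma2`, and `eig` are introduced here only for the comparison):

```r
rho    <- -0.65
sigmas <- c(1, 0.75)  # sd of lV0, sd of r

# Option 1: fill in the variances and the covariance by hand
cov_lV0r <- sigmas[1] * sigmas[2] * rho
Sigma1 <- matrix(c(sigmas[1]^2, cov_lV0r,
                   cov_lV0r,    sigmas[2]^2),
                 ncol = 2, byrow = TRUE)

# Option 2: scale the correlation matrix by the standard deviations
Rho    <- matrix(c(1, rho, rho, 1), ncol = 2, byrow = TRUE)
Sigma2 <- diag(sigmas) %*% Rho %*% diag(sigmas)

all.equal(Sigma1, Sigma2)        # same matrix either way
all.equal(cov2cor(Sigma2), Rho)  # converting back recovers Rho

# A valid covariance matrix has only positive eigenvalues
eig <- eigen(Sigma2, symmetric = TRUE)$values
```

Option 2 scales more easily to more than two parameters, and it is the form that the correlated model uses: separate priors on the standard deviations and on the correlation matrix.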

Simulate parameters by strain

nS <- 20 # number of strains

params <- MASS::mvrnorm(n=nS, mu=Mu, Sigma=Sigma)
colnames(params) <- c("lV0", "r")
(params <- as.data.frame(params))
         lV0         r
1   9.452402 3.7275764
2   8.711203 2.6155409
3   9.791286 2.0081441
4  10.086890 2.3782600
5   9.640374 1.1717285
6  10.016767 2.7410814
7   9.312360 2.1956877
8   9.183789 2.6614482
9   7.293357 4.1114728
10  6.945774 3.0949708
11 10.593870 1.5277039
12 10.766241 2.0769321
13 10.124871 0.9640900
14  9.991686 0.9937701
15  9.564620 2.1413929
16  9.312226 3.3501344
17  8.464153 2.6831175
18 10.317536 2.0382876
19  9.901750 1.1130668
20  9.120904 2.3364231
plot(r ~ lV0, data = params)
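
With only 20 strains, the realized correlation between the simulated parameters can be quite far from `rho`. The sketch below, which assumes the same `Mu` and `Sigma` as above and uses arbitrary seeds, compares one very large draw with many repeated draws of 20 strains.

```r
set.seed(1)  # arbitrary seed, for reproducibility of this sketch only
Mu    <- c(log(1e4), 2.5)
Sigma <- matrix(c(1, -0.4875, -0.4875, 0.5625), ncol = 2)

# Many draws: the sample correlation should sit close to rho = -0.65
big <- MASS::mvrnorm(n = 1e5, mu = Mu, Sigma = Sigma)
cor_big <- cor(big[, 1], big[, 2])

# Repeated small studies of 20 strains: the sample correlation varies a lot
cor_small <- replicate(1000, {
  p <- MASS::mvrnorm(n = 20, mu = Mu, Sigma = Sigma)
  cor(p[, 1], p[, 2])
})
quantile(cor_small, c(0.055, 0.5, 0.945))
```

This spread is worth keeping in mind when judging whether a fitted model "recovered" the correlation: the model can only learn about the 20 strains it was given.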

Finally, time to simulate observations!

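time <- c(1, 3, 5) # time points (days)
sigma <- 0.25 # observation error (sd on the log scale)

# one row per strain x time point
df <- expand.grid(ID = 1:nS,
                  time = time)
# observed log titer = strain-specific line + normal noise
df$lV <- rnorm(n = nrow(df), 
               mean = params$lV0[df$ID] + params$r[df$ID] * df$time, 
               sd = sigma)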

cbind( head(df), tail(df))
  ID time       lV ID time       lV
1  1    1 13.24991 15    5 20.00314
2  2    1 11.19361 16    5 25.65776
3  3    1 11.73749 17    5 21.60064
4  4    1 12.34671 18    5 20.29491
5  5    1 10.88544 19    5 15.23627
6  6    1 12.99659 20    5 20.94581
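
One simple sanity check on the simulated data is to fit a separate least-squares line to each strain (no pooling) and compare the coefficients with the values used to simulate. The sketch below regenerates a data set of the same form so that it runs on its own. The seed is arbitrary and `fits` is a name introduced here.

```r
set.seed(2)  # arbitrary seed for this sketch
nS    <- 20
Sigma <- matrix(c(1, -0.4875, -0.4875, 0.5625), ncol = 2)
params <- as.data.frame(MASS::mvrnorm(n = nS, mu = c(log(1e4), 2.5),
                                      Sigma = Sigma))
colnames(params) <- c("lV0", "r")

df <- expand.grid(ID = 1:nS, time = c(1, 3, 5))
df$lV <- rnorm(nrow(df),
               mean = params$lV0[df$ID] + params$r[df$ID] * df$time,
               sd = 0.25)

# No pooling: one least-squares line per strain
fits <- t(sapply(split(df, df$ID),
                 function(d) coef(lm(lV ~ time, data = d))))
colnames(fits) <- c("lV0_hat", "r_hat")

# Compare the estimates with the values used to simulate
cor(fits[, "r_hat"], params$r)
cor(fits[, "lV0_hat"], params$lV0)
```

The errors in a least-squares intercept and slope are themselves negatively correlated whenever the mean sampling time is above zero. The raw correlation between these per-strain estimates is therefore not a clean estimate of `rho`, which is one motivation for the multilevel models that follow.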

Analysis goal

Our research questions are:

  1. Do virus strains have substantially different \(r\)’s and \(\log(V_0)\)’s?
  2. Are the \(r\)’s and \(\log(V_0)\)’s negatively correlated?

How would you analyze these data?

Varying slopes & intercepts model

a & b vary by clusters, but are independent

\[ \begin{aligned} \log(V_i) &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= a_{\text{Strain}[i]} + b_{\text{Strain}[i]}\times \text{time}_i \\ a_{\text{Strain}} &\sim \text{Normal}(\mu_{a},\sigma_a) \\ b_{\text{Strain}} &\sim \text{Normal}(\mu_{b}, \sigma_b) \\ \mu_{a} &\sim \text{Normal}(4,1.5) \\ \sigma_{a} &\sim \text{Exponential}(2) \\ \mu_{b} &\sim \text{Normal}(1,1) \\ \sigma_{b} &\sim \text{Exponential}(3) \\ \sigma &\sim \text{Exponential}(2) \end{aligned} \]

a & b vary by clusters, but are correlated

\[ \begin{aligned} \log(V_i) &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= a_{\text{Strain}[i]} + b_{\text{Strain}[i]}\times \text{time}_i \\ \left[\begin{matrix} a_{\text{Strain}} \\ b_{\text{Strain}} \end{matrix} \right] &\sim \text{MVNormal}\left(\left[\begin{matrix} \mu_{a} \\ \mu_{b} \end{matrix} \right],\Sigma \right) \\ \Sigma &= \left( \begin{matrix} \sigma_a & 0 \\ 0 & \sigma_b \end{matrix} \right) \text{Rho} \left( \begin{matrix} \sigma_a & 0 \\ 0 & \sigma_b \end{matrix} \right) \\ \mu_{a} &\sim \text{Normal}(4,1.5) \\ \sigma_{a} &\sim \text{Exponential}(2) \\ \mu_{b} &\sim \text{Normal}(1,1) \\ \sigma_{b} &\sim \text{Exponential}(3) \\ \text{Rho} &\sim \text{LKJcorr}(2) \\ \sigma &\sim \text{Exponential}(2) \end{aligned} \]
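
To see what the LKJcorr(2) prior implies here, note that for a 2×2 correlation matrix the LKJ(\(\eta\)) density is proportional to \((1-\rho^2)^{\eta-1}\), which is the same as \((\rho+1)/2 \sim \text{Beta}(\eta,\eta)\). The sketch below uses that identity to draw from the prior with base R. It applies to the 2×2 case only, and the seed and the name `rho_prior` are arbitrary.

```r
set.seed(3)  # arbitrary seed for this sketch
eta <- 2
n   <- 1e5

# 2x2 case only: (rho + 1) / 2 ~ Beta(eta, eta)
rho_prior <- 2 * rbeta(n, eta, eta) - 1

mean(rho_prior)                       # centred on zero
quantile(rho_prior, c(0.055, 0.945))  # central 89% prior interval
mean(abs(rho_prior) > 0.9)            # prior mass on extreme correlations
```

With \(\eta = 2\) the prior is mildly skeptical of correlations near \(-1\) or \(1\) but still leaves plenty of room for a value like \(-0.65\).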

Varying slopes & intercepts model, in ulam()

a & b vary by clusters, but are independent

\[ \begin{aligned} \log(V_i) &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= a_{\text{Strain}[i]} + b_{\text{Strain}[i]}\times \text{time}_i \\ a_{\text{Strain}} &\sim \text{Normal}(\mu_{a},\sigma_a) \\ b_{\text{Strain}} &\sim \text{Normal}(\mu_{b}, \sigma_b) \\ \mu_{a} &\sim \text{Normal}(4,1.5) \\ \sigma_{a} &\sim \text{Exponential}(2) \\ \mu_{b} &\sim \text{Normal}(1,1) \\ \sigma_{b} &\sim \text{Exponential}(3) \\ \sigma &\sim \text{Exponential}(2) \end{aligned} \]

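library(rethinking) # provides ulam() and precis()

m1 <- ulam(
  alist(
    # likelihood: log titer is linear in time for each strain
    lV ~ dnorm(mu, sigma), 
    mu <- a[ID] + b[ID]*time,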
    
    # priors
    a[ID] ~ dnorm(a_mu, a_sd),
    a_mu ~ dnorm(4,1.5),
    a_sd ~ dexp(2),
    
    b[ID] ~ dnorm(b_mu, b_sd),
    b_mu ~ dnorm(1,1),
    b_sd ~ dexp(3),
    
    sigma ~ dexp(2)
  ), data = df
)

Varying slopes & intercepts model, in ulam()

a & b vary by clusters, but are correlated

m2 <- ulam(
  alist(
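    lV ~ dnorm(mu, sigma), 
    mu <- a[ID] + b[ID]*time,
    
    # priors
    # a and b for each strain are drawn jointly, with correlation matrix Rho
    c(a, b)[ID] ~ multi_normal( c(a_mu, b_mu), Rho, sig_ID),
    a_mu ~ dnorm(4,1.5),
    b_mu ~ dnorm(1,1),
    sig_ID ~ dexp(2), # sds of a and b; both get the same dexp(2) prior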
    
    Rho ~ lkj_corr(2),
    sigma ~ dexp(2)
  ), data = df
)

\[ \begin{aligned} \log(V_i) &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= a_{\text{Strain}[i]} + b_{\text{Strain}[i]}\times \text{time}_i \\ \left[\begin{matrix} a_{\text{Strain}} \\ b_{\text{Strain}} \end{matrix} \right] &\sim \text{MVNormal}\left(\left[\begin{matrix} \mu_{a} \\ \mu_{b} \end{matrix} \right],\Sigma \right) \\ \Sigma &= \left( \begin{matrix} \sigma_a & 0 \\ 0 & \sigma_b \end{matrix} \right) \text{Rho} \left( \begin{matrix} \sigma_a & 0 \\ 0 & \sigma_b \end{matrix} \right) \\ \mu_{a} &\sim \text{Normal}(4,1.5) \\ \sigma_{a} &\sim \text{Exponential}(2) \\ \mu_{b} &\sim \text{Normal}(1,1) \\ \sigma_{b} &\sim \text{Exponential}(3) \\ \text{Rho} &\sim \text{LKJcorr}(2) \\ \sigma &\sim \text{Exponential}(2) \end{aligned} \]

Varying slopes & intercepts models: results

a & b vary by clusters, but are independent

precis(m1, depth=2)
# A tibble: 45 × 7
   par     mean     sd `5.5%` `94.5%`  rhat ess_bulk
   <chr>  <dbl>  <dbl>  <dbl>   <dbl> <dbl>    <dbl>
 1 a[1]   9.51  0.348   8.97   10.1   1.00      481.
 2 a[2]   8.78  0.345   8.29    9.37  0.999     444.
 3 a[3]   9.67  0.320   9.18   10.2   1.00      560.
 4 a[4]   9.75  0.316   9.24   10.2   1.00      372.
 5 a[5]   9.45  0.326   8.92    9.96  1.00      422.
 6 a[6]  10.1   0.341   9.58   10.6   0.999     467.
 7 a[7]   9.21  0.358   8.61    9.77  1.00      468.
 8 a[8]   9.12  0.328   8.63    9.67  1.00      494.
 9 a[9]   7.70  0.373   7.13    8.30  1.01      295.
10 a[10]  7.57  0.373   6.99    8.18  0.998     397.
11 a[11] 10.2   0.318   9.67   10.7   1.00      553.
12 a[12] 10.8   0.360  10.2    11.3   1.00      489.
13 a[13] 10.6   0.377  10.0    11.2   1.01      392.
14 a[14]  9.07  0.357   8.49    9.62  0.998     488.
15 a[15]  9.77  0.349   9.20   10.3   1.00      577.
16 a[16]  9.77  0.327   9.24   10.3   1.00      533.
17 a[17]  8.23  0.354   7.67    8.84  1.00      338.
18 a[18] 10.1   0.352   9.54   10.6   1.00      472.
19 a[19] 10.0   0.313   9.53   10.5   1.01      374.
20 a[20]  9.48  0.331   8.94    9.99  0.999     511.
21 a_mu   9.32  0.246   8.94    9.70  1.01      321.
22 a_sd   0.945 0.198   0.680   1.27  1.00      297.
23 b[1]   3.77  0.101   3.61    3.93  1.00      439.
24 b[2]   2.59  0.101   2.42    2.76  1.00      439.
25 b[3]   2.00  0.0979  1.85    2.15  0.999     575.
26 b[4]   2.50  0.0953  2.35    2.66  1.00      366.
27 b[5]   1.27  0.0979  1.11    1.43  1.00      450.
28 b[6]   2.75  0.103   2.59    2.92  1.00      496.
29 b[7]   2.22  0.108   2.05    2.40  1.00      530.
30 b[8]   2.69  0.0998  2.52    2.85  0.999     414.
31 b[9]   4.03  0.108   3.85    4.19  1.01      328.
32 b[10]  2.94  0.111   2.76    3.11  0.998     399.
33 b[11]  1.58  0.0973  1.43    1.74  1.00      489.
34 b[12]  2.11  0.107   1.94    2.28  1.00      409.
35 b[13]  0.866 0.113   0.703   1.06  1.01      421.
36 b[14]  1.24  0.109   1.06    1.41  1.01      473.
37 b[15]  2.06  0.104   1.90    2.23  1.00      486.
38 b[16]  3.16  0.0983  3.01    3.31  1.01      390.
39 b[17]  2.70  0.109   2.52    2.87  0.999     327.
40 b[18]  2.10  0.106   1.93    2.26  1.00      549.
41 b[19]  1.07  0.0916  0.936   1.22  1.01      372.
42 b[20]  2.27  0.105   2.10    2.43  0.998     491.
43 b_mu   2.25  0.199   1.93    2.56  1.00      911.
44 b_sd   0.845 0.138   0.658   1.08  1.00      733.
45 sigma  0.299 0.0539  0.229   0.397 0.999     150.

a & b vary by clusters, but are correlated

precis(m2, depth=3)
# A tibble: 49 × 7
   par         mean     sd `5.5%` `94.5%`   rhat ess_bulk
   <chr>      <dbl>  <dbl>  <dbl>   <dbl>  <dbl>    <dbl>
 1 b[1]       3.79  0.106   3.62    3.96   0.999     635.
 2 b[2]       2.62  0.0976  2.47    2.77   1.01      585.
 3 b[3]       1.99  0.100   1.83    2.15   1.01      886.
 4 b[4]       2.51  0.105   2.34    2.68   0.999     888.
 5 b[5]       1.24  0.0982  1.09    1.39   1.00      764.
 6 b[6]       2.75  0.0963  2.59    2.90   1.00      686.
 7 b[7]       2.21  0.119   2.03    2.39   1.00      998.
 8 b[8]       2.70  0.0999  2.54    2.86   1.02      827.
 9 b[9]       4.06  0.104   3.88    4.21   1.01      412.
10 b[10]      2.95  0.111   2.78    3.12   0.998     636.
11 b[11]      1.57  0.0991  1.42    1.73   0.999     717.
12 b[12]      2.11  0.0975  1.97    2.25   0.998     649.
13 b[13]      0.846 0.101   0.690   1.01   0.999     781.
14 b[14]      1.21  0.0906  1.07    1.35   1.00     1053.
15 b[15]      2.05  0.106   1.88    2.22   1.00      659.
16 b[16]      3.19  0.0932  3.04    3.32   0.998     799.
17 b[17]      2.71  0.108   2.53    2.88   0.999     496.
18 b[18]      2.10  0.0953  1.94    2.25   0.998     712.
19 b[19]      1.04  0.0938  0.896   1.19   1.01      778.
20 b[20]      2.27  0.0983  2.12    2.43   1.00      601.
21 a[1]       9.42  0.351   8.82    9.96   0.999     623.
22 a[2]       8.72  0.320   8.21    9.22   1.00      622.
23 a[3]       9.72  0.327   9.22   10.2    1.00      854.
24 a[4]       9.73  0.345   9.12   10.3    1.00      947.
25 a[5]       9.59  0.320   9.09   10.1    1.00      770.
26 a[6]      10.1   0.320   9.62   10.6    0.998     689.
27 a[7]       9.23  0.391   8.64    9.84   0.999     875.
28 a[8]       9.12  0.340   8.56    9.67   1.00      793.
29 a[9]       7.57  0.349   7.05    8.16   1.00      625.
30 a[10]      7.53  0.376   6.97    8.13   0.998     802.
31 a[11]     10.2   0.339   9.66   10.7    1.00      648.
32 a[12]     10.8   0.332  10.2    11.3    0.999     575.
33 a[13]     10.7   0.344  10.2    11.3    0.999     733.
34 a[14]      9.17  0.300   8.71    9.64   0.998     829.
35 a[15]      9.80  0.354   9.23   10.3    0.999     710.
36 a[16]      9.66  0.334   9.15   10.2    1.00      734.
37 a[17]      8.20  0.350   7.69    8.76   0.998     490.
38 a[18]     10.1   0.331   9.59   10.6    1.00      797.
39 a[19]     10.1   0.321   9.63   10.7    1.01      589.
40 a[20]      9.48  0.322   8.94    9.99   0.999     678.
41 a_mu       9.36  0.229   9.00    9.74   1.00      498.
42 b_mu       2.30  0.189   2.02    2.63   0.999     728.
43 sig_ID[1]  0.957 0.170   0.720   1.27   1.00      451.
44 sig_ID[2]  0.864 0.137   0.682   1.11   1.01      587.
45 Rho[1,1]   1     0       1       1     NA          NA 
46 Rho[2,1]  -0.458 0.184  -0.716  -0.146  1.02      553.
47 Rho[1,2]  -0.458 0.184  -0.716  -0.146  1.02      553.
48 Rho[2,2]   1     0       1       1     NA          NA 
49 sigma      0.291 0.0541  0.223   0.390  1.02      105.

Did we estimate the correlation correctly?

Recovering the parameters

Plot legend: red is the estimate without correlation (m1); blue is the estimate with correlation (m2).
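
One rough way to approach the question of whether the correlation was estimated correctly is to compare the interval for `Rho[2,1]` from m2 with two targets: the value of `rho` used in the simulation, and the correlation actually realized among the 20 simulated strains. The sketch below copies the strain-level values from the `params` table and the interval from the `precis(m2, depth=3)` output. The names `rho_sample`, `ci_m2`, `covers_true`, and `covers_sample` are introduced here.

```r
# Strain-level values copied from the simulated `params` table
lV0 <- c(9.452402, 8.711203, 9.791286, 10.086890, 9.640374,
         10.016767, 9.312360, 9.183789, 7.293357, 6.945774,
         10.593870, 10.766241, 10.124871, 9.991686, 9.564620,
         9.312226, 8.464153, 10.317536, 9.901750, 9.120904)
r <- c(3.7275764, 2.6155409, 2.0081441, 2.3782600, 1.1717285,
       2.7410814, 2.1956877, 2.6614482, 4.1114728, 3.0949708,
       1.5277039, 2.0769321, 0.9640900, 0.9937701, 2.1413929,
       3.3501344, 2.6831175, 2.0382876, 1.1130668, 2.3364231)

rho_true   <- -0.65        # value used in the simulation
rho_sample <- cor(lV0, r)  # realized correlation among these 20 strains

# 89% interval for Rho[2,1] as reported by precis(m2, depth=3)
ci_m2 <- c(lower = -0.716, upper = -0.146)

# Does the interval cover the simulated value and the realized one?
covers_true   <- rho_true   > ci_m2["lower"] && rho_true   < ci_m2["upper"]
covers_sample <- rho_sample > ci_m2["lower"] && rho_sample < ci_m2["upper"]
```

The posterior mean of about \(-0.46\) is closer to zero than \(-0.65\), which is what one would expect from a prior centred on zero and only 20 strains, but the interval is wide and still includes the simulated value.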

A non-centered version

m3 <- ulam(
  alist(
    lV ~ dnorm(mu, sigma), 
    # strain-specific intercept and slope = population mean + strain offset
    mu <- (a_mu + alpha[ID, 1]) + (b_mu + alpha[ID, 2])*time,
    
    # priors
    # adaptive priors - non-centered
    transpars> matrix[ID, 2]:alpha <-
      compose_noncentered( sig_ID, L_Rho_ID, z_ID),
    
    matrix[2,ID]:z_ID ~ normal( 0 , 1 ),
    a_mu ~ dnorm(4,1.5),
    b_mu ~ dnorm(1,1),
    vector[2]:sig_ID ~ dexp(2),
    
    cholesky_factor_corr[2]:L_Rho_ID ~ lkj_corr_cholesky(2),
    sigma ~ dexp(2),
    
    # compute ordinary correlation matrices from Cholesky factors
    gq> matrix[2,2]:Rho_ID <<- Chol_to_Corr(L_Rho_ID)
  ), data = df
)
Correlation estimated by the centered model (m2):

               mean        sd       5.5%      94.5%     rhat ess_bulk
Rho[1,1]  1.0000000 0.0000000  1.0000000  1.0000000       NA       NA
Rho[2,1] -0.4579968 0.1838399 -0.7159485 -0.1457048 1.020857 553.3067
Rho[1,2] -0.4579968 0.1838399 -0.7159485 -0.1457048 1.020857 553.3067
Rho[2,2]  1.0000000 0.0000000  1.0000000  1.0000000       NA       NA

Correlation estimated by the non-centered model (m3):

                  mean        sd       5.5%      94.5%     rhat ess_bulk
Rho_ID[1,1]  1.0000000 0.0000000  1.0000000  1.0000000       NA       NA
Rho_ID[2,1] -0.4805687 0.1664392 -0.7061873 -0.1701685 1.005384 122.7034
Rho_ID[1,2] -0.4805687 0.1664392 -0.7061873 -0.1701685 1.005384 122.7034
Rho_ID[2,2]  1.0000000 0.0000000  1.0000000  1.0000000       NA       NA
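
The idea behind the non-centered form can be sketched in base R: start from independent standard-normal z-scores, then scale and correlate them with the standard deviations and a Cholesky factor of the correlation matrix. This is a conceptual illustration of what the `compose_noncentered()` step is understood to do, not the rethinking implementation. The seed and the values of `sigmas` and `Rho` are assumptions taken from the simulation settings.

```r
set.seed(4)  # arbitrary seed for this sketch
sigmas <- c(1, 0.75)
Rho    <- matrix(c(1, -0.65, -0.65, 1), ncol = 2)
nS     <- 1e5  # many "strains" so the check below is precise

# Lower-triangular Cholesky factor: L %*% t(L) equals Rho
L <- t(chol(Rho))

# Independent standard-normal z-scores, one column per strain
z <- matrix(rnorm(2 * nS), nrow = 2)

# Scale and correlate the z-scores, then transpose to strains x parameters
alpha <- t(diag(sigmas) %*% L %*% z)

# The offsets have (approximately) the intended covariance matrix
Sigma_target <- diag(sigmas) %*% Rho %*% diag(sigmas)
Sigma_sim    <- cov(alpha)
```

Because the sampler works with the uncorrelated z-scores instead of the correlated offsets, the posterior geometry is often easier to explore, even though both parameterizations describe the same model.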